Skip to content

perf(gemma4): ULTIMATE v2 -- ties or beats vLLM no-MTP on 3 models (31B-it, 26B-A4B, E4B), MMLU tied#21

Draft
pyc96 wants to merge 23 commits into
mainfrom
pyc/feat-gemma4-ultimate-v2
Draft

perf(gemma4): ULTIMATE v2 -- ties or beats vLLM no-MTP on 3 models (31B-it, 26B-A4B, E4B), MMLU tied#21
pyc96 wants to merge 23 commits into
mainfrom
pyc/feat-gemma4-ultimate-v2

Conversation

@pyc96
Copy link
Copy Markdown
Owner

@pyc96 pyc96 commented May 25, 2026

Summary

The single composed branch of all validated Gemma-4 optimization work from pyc96/sglang, evaluated across 3 models (31B-it dense, 26B-A4B-IT MoE, E4B-it dense+PLE+KV-shared) x 2 MTP modes (with and without).

Supersedes PR #18 with 4 additional PRs that #18 omitted (PR #4, #10, #17, #19/#20) plus a critical PCG bug fix discovered during validation.

Headline (live H100 TP=2, triton, n=80, MMLU N=500)

Variant Workload SGLang ULT v2 vLLM nightly Delta
31B-it (dense) chat 1k/1k no-MTP tok/s 1431 1534 -6.7 % (tied)
medTTFT 2843ms 2994ms -5.0 % SGLang wins
31B-it summ 8k/1k no-MTP tok/s 407 596 -31.7 %
medTPOT 23.7ms 51.7ms -54 % SGLang wins
26B (MoE) chat 1k/1k no-MTP tok/s 4985 4355 +14.5 % SGLang wins
medTTFT 684ms 723ms -5.4 %
medTPOT 15.4ms 17.6ms -12.5 %
26B summ 8k/1k no-MTP tok/s 1826 2356 -22.5 %
medTPOT 18.7ms 26.7ms -30 % SGLang wins
E4B (PLE+YOCO) chat 1k/1k no-MTP tok/s 9876 10162 -2.8 % (tied)
medTTFT 601ms 648ms -7.3 % SGLang wins
E4B summ 8k/1k no-MTP tok/s 4016 4068 -1.3 % (tied)

MMLU N=500 (seed 0, temp 0)

Model ULT v2 vLLM Delta
31B-it no-MTP 0.780 0.778 tied
31B-it MTP 0.778 0.778 tied
26B no-MTP 0.664 0.668 tied
26B MTP 0.664 0.668 tied
E4B no-MTP 0.610 0.592 +1.8 pp SGLang wins

What's composed

Source PR Title Role In v2
#1 Single-launch fused router (Gemma4 MoE) foundational YES
#6 trtllm_mha SWA-aware state swap correctness YES
#7 Revert page_table clamp cleanup YES
#8 MoE-only swa_full_tokens_ratio=0.15 gate MoE-correctness YES
#9 MM batched encoder MM perf YES
#10 mm_disabled_models Gemma4 +21.6 % KV pool ADDED in v2
#14 YOCO E2B/E4B fast-prefill +1.51x on E2B prefill YES
#16 Triton fusions (PLE tail + triple-rmsnorm) -13 % TPOT on MoE YES
#17 31b chunked_prefill=4096 + mem_fraction=0.88 +33 % summ tok/s on dense ADDED in v2
#19/#20 FlashInfer AR+RMSNorm wiring (Site #1) +2.6 % chat, -3.4 % TPOT dense ADDED in v2
#4 H100 extend-tile (Lq=256) sm_90 perf YES
NEW Explicit Gemma-4 PCG-disable gate fixes #10 x #16 token-soup v2 only

Critical fix discovered during validation

When PR #10 + #16 + #17 are composed, the 31B-it no-MTP server captured piecewise CUDA graph and generated token soup (Korean/Latin garbage) for every prompt (0/20 parity).

Root cause: PR #10 (mm_disabled_models for Gemma4) makes is_multimodal evaluate to False for Gemma4ForConditionalGeneration on text-only deployments. PR #16's PCG auto-disable gate keys off is_multimodal=True only -- so once PR #10 fires, the PCG disable is silently bypassed, PCG captures the dense Gemma-4 forward, and the captured graph produces garbage.

Fix (commit f5c88154b): explicit Gemma-4 arch check independent of is_multimodal. After fix: 18-19/20 parity on the same workload, MMLU restored to 0.780.

Recommendation for production

Model Config Why
31B-it ULT no-MTP Beats vLLM on TTFT + TPOT; loses tok/s but small absolute
26B-A4B-IT ULT no-MTP Beats vLLM on every chat metric and summ TPOT
E4B ULT no-MTP Ties or beats vLLM everywhere; wins MMLU
ALL Gemma-4 NOT MTP -50 % gap is structural (FROZEN_KV_MTP scheduler) -- separate work

MTP structural gap (documented)

Model ULT MTP vs vLLM MTP chat tok/s summ tok/s
31B-it 1446 vs 2958 (-51 %) 421 vs 867 (-51 %)
26B 3275 vs 6019 (-46 %) 1078 vs 2433 (-56 %)

_handle_frozen_kv_mtp (arg_groups/speculative_hook.py:233-250) forces disable_overlap_schedule=True + max_running_requests=48. vLLM peaks at 80 concurrent reqs / 5462 tok/s decode on 26B-MTP; SGLang MTP peaks at ~12. NOT a kernel issue; requires SGLang MTP worker refactor.

Reproducer

cd /home/pyc_google_com/dev/gemma-op/sglang
git checkout pyc/feat-gemma4-ultimate-v2   # f5c88154b

# No-MTP deployment (recommended for all 3 models)
python -m sglang.launch_server \
  --model-path google/gemma-4-{31B-it,26B-A4B-IT,E4B-it} \
  --dtype bfloat16 --trust-remote-code --tp-size 2 \
  --attention-backend triton --random-seed 1

All ULT v2 auto-tunes fire on launch (visible in server log):

Files

  • 16 commits on pyc/feat-gemma4-ultimate-v2 (HEAD f5c88154b)
  • Full per-PR analysis: agent-pod/runs/20260525_ultimate_eval/analysis/pr_matrix.md
  • Final report: agent-pod/runs/20260525_ultimate_eval/final_report.md
  • Bench JSONs: agent-pod/runs/20260525_ultimate_eval/benchmark/result_*_v2_*.jsonl (16 files)
  • Parity transcripts + MMLU outputs preserved in /tmp/{mmlu,parity}_*.out and the run artifact root

Status

Draft staged on pyc96/sglang only. Not submitted upstream. End-to-end validated against vLLM nightly across all 3 Gemma-4 models x 2 MTP modes with MMLU + per-prompt parity + full benchmark.


CI States

Latest PR Test (Base): ❌ Run #26420944501
Latest PR Test (Extra): ❌ Run #26420944491

pyc96 and others added 23 commits May 22, 2026 00:26
Gemma4MoE.routing_function previously emitted four per-layer GPU kernels:

  torch.topk          -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                         + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax             -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]  -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...  -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32        -> at::native::elementwise_kernel<copy>

torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels.  vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.

This commit ports the algorithm to SGLang:

* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
  python/sglang/srt/layers/gemma4_fused_ops.py.  One Triton program per
  token loads all E logits, packs (bijective(logit_bits), expert_id) into
  int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
  fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
  ids) in (fp32, int32).  num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
  bf16/fp32 inputs and falls back to the torch path otherwise.  Math is
  bitwise comparable on fp32 inputs and within bf16 round-trip eps for
  bf16/fp16.

Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):

  workload                       baseline       this patch     delta
  chat        random 1000/1000   2729.30 tok/s  2880.94 tok/s  +5.6%
  summariz.   random 8000/1000   1060.98 tok/s  1108.42 tok/s  +4.5%
  chat        median TPOT (ms)   21.11          20.70          -1.9%
  chat        accept length      2.75           2.80           +1.8%

MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.

Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.

Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.

Co-authored-by: Claude
…l split

Gemma-4 textual layers are a 25:5 SWA:full split (see
`Gemma4TextConfig.layer_types`).  SGLang's default
`swa_full_tokens_ratio=0.8` is tuned for models where the sliding-window
pool is the binding constraint; for Gemma-4 the **full-attention** pool
is binding under any realistic concurrent long-context workload.

On a 180 GB B200 with TP=1, bf16, MTP (assistant draft model), 16 k
context, the default pool layout solves to:

  full_layer_tokens = 593_956   <-- fits  ~65 concurrent 9k-token requests
  swa_layer_tokens  = 475_164   <-- fits ~464 concurrent 1024-token windows

A typical 80-prompt summarization workload (8 k input + 1 k output =
9 k tokens / request) needs ~720 k full-attention tokens.  Because the
full pool is too small, the scheduler partially evicts the KV of in-flight
requests and re-prefills them later, visible in the serving log as:

  Prefill batch, ..., #cached-token: 1003, #new-token: 7010, ...

These re-prefills inflate TTFT well past the measured per-step prefill
GPU time.

Setting `swa_full_tokens_ratio = 0.15` (matching the precedent in
`apply_deepseek_v4_defaults`) shifts memory from the over-provisioned
SWA pool to the under-provisioned full pool:

  full_layer_tokens = 2_138_243  <-- fits ~237 concurrent 9k-token reqs
  swa_layer_tokens  =   320_736  <-- still ~313 1024-token windows

Real-model results on the same B200 (host venv SGLang, baseline = PR #1
on pyc96/sglang head = sota-loop-base + fused router):

  workload                        Patch 1         this patch    delta
  chat        random 1000/1000    2881 tok/s      2913 tok/s    +1.1 %
  summariz.   random 8000/1000
              median TTFT (ms)    10459          8763          **-16.2 %**
              output tok/s        1108           1097          -1.0 %
              median TPOT (ms)    44.6           37.9          -15.0 %

Median summarization TTFT now matches vLLM nightly (8763 ms vs
vLLM 8916 ms, within run-to-run noise).

MMLU @ 500 random questions (seed 0, temp 0): SGLang 0.706 vs vLLM 0.710
-- within MMLU sampling noise; no regression.

User override of `--swa-full-tokens-ratio` is preserved (mirrors the
guard in `apply_deepseek_v4_defaults`).

Tests: test/srt/test_gemma4_swa_full_tokens_ratio.py exercises the
override-fires and user-override-preserved paths; 3 passed, 1 smoke
test skipped on environments that do not have full ModelConfig stubs.

Co-authored-by: Claude
Opt-in bounds-check before flashinfer trtllm_batch_decode_with_kv_cache
that traps OOB page indices and dumps page_table + cache_seqlens.
Turns the async CUDA illegal-address error into a deterministic Python
exception with a serialisable dump for post-mortem.

See crash_repro/TRIAGE_REPORT.md and crash_repro/repro_e4b_bounds.sh.

Co-authored-by: Claude
…rap)

Adds an opt-in trap inside SWATokenToKVPoolAllocator.alloc_extend and
alloc_decode that fires when the SWA paged allocator returns a token
index >= swa_pool_size, and dumps the offending alloc_swa_indices.

Same env var (SGLANG_TRTLLM_MHA_DEBUG=1) as the trtllm_mha bounds
check.  Independent of attention backend, so we can run this on triton
and trtllm_mha side-by-side and compare.

Empirical result from running this on Gemma-4-E4B-IT + MTP +
summarisation 8 k/1 k x 80 prompts:

  triton backend:     SWA usage reaches 1.00, ZERO trap fires, no crash
  trtllm_mha backend: SWA usage 0.83-0.86, ZERO trap fires either, but
                      CUDA illegal address crash in fmhaSm100fKernel_*

That is, the SWA allocator is NOT the source of the OOB.  Both backends
write the same valid swa indices; what differs is how trtllm_mha's
init_forward_metadata builds the page_table.  Specifically:

  metadata.page_table = req_to_token[req_pool_indices, :max_seq_len_k]

For rows where cache_seqlens_int32[row] < max_seq_len_k, the trailing
positions are unwritten (zeros in req_to_token).  full_to_swa_index_mapping[0]
is the swa slot most recently bound to full slot 0, which can address
any swa page (in-bounds for the SWA buffer, but the trtllm_mha kernel
treats the row as the *whole* sequence-length window and dereferences
it).

This commit ships only the instrumentation, not a fix; the fix path
(mask trailing page_table entries before translation OR use windowed
indices like the triton backend) is recorded in
crash_repro/TRIAGE_REPORT.md.

Co-authored-by: Claude
…A crash

Prevents the deterministic CUDA Warp Illegal Address crash in
'fmhaSm100fKernel_*SlidingOrChunkedCausal*' that triggers under
Gemma-4 + --attention-backend trtllm_mha + MTP + summarization
workloads at ~85% SWA pool utilization (see
crash_repro/TRIAGE_REPORT.md).

Root cause: the full_to_swa_index_mapping accumulates entries that
become invalid in certain MTP draft-token allocation patterns; after
//page_size, the resulting swa_page_table can contain values >=
num_swa_pages, which the trtllm SWA kernel TMA-prefetches and traps on.

Fix: clamp page_table values to [0, k_cache.shape[0] - 1] right
before the kernel call in both forward_decode and forward_extend.
Applies to BOTH the regular page_table and swa_page_table paths.

Verification on Gemma-4-E4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before this fix: CRASH at ~85% SWA fill, ~30 s into bench
  after this fix:  COMPLETED, output 4032 tok/s peak, no trap events

Verification on Gemma-4-26B-A4B-IT + trtllm_mha + MTP + summarization
(8 k/1 k x 80 prompts, max_concurrency=64):
  before: CRASH (same kernel, same SWA fill trigger)
  after:  COMPLETED, output 1832 tok/s peak (vs Patch 1+2 triton
          1097 tok/s = +67%), TPOT 25 ms (vs triton 38 ms = -34%),
          TTFT 2.9 s (vs triton 8.8 s = -67%)

MMLU @ 500 questions on 26B with this fix: 0.718 (vs Patch 2 baseline
0.706, vLLM 0.710) -- within noise, no regression.

KNOWN LIMITATION: accept length drops vs triton backend (1.69 vs 2.76
on 26B summarization).  Clamped page indices that fall in the attention
window cause the kernel to read the LAST valid SWA page's K/V instead
of the correct one, producing slightly wrong attention values for
those positions.  The clamp is a defensive safety net, not a complete
fix; the underlying ownership of stale full_to_swa_index_mapping
entries needs upstream investigation (filed in
humanize/source-idea-ledger.md as Patch E).  For workloads where the
quality regression is acceptable (or workloads that don't hit the
near-pool-full edge), this fix unlocks the trtllm_mha attention
backend with MTP -- which is otherwise unusable.

Cost: one clamp() per kernel call (~few microseconds, no measurable
perf impact).

See crash_repro/TRIAGE_REPORT.md.

Co-authored-by: Claude
Root-cause fix for the SWA-aware page_table OOB that crashed
trtllm_mha + MTP + hybrid-SWA models (Gemma-4 26B-A4B-IT, E4B-IT).

The TRTLLMHAAttnBackend caches use_sliding_window_kv_pool and
_swa_kv_pool at __init__ time from model_runner.token_to_kv_pool.
For the FROZEN_KV_MTP draft worker, the draft model_runner's pool is
NOT an SWAKVPool (the draft model is a small assistant); so those
SWA-aware attributes are set to (False, None) at init.

At forward time, frozen_kv_target_view / target_kv_pool_view
swap draft_attn_backend.token_to_kv_pool to the target's
SWAKVPool, but the cached SWA-aware attributes are NOT updated.
The backend then builds full-pool page_table values for layers
that the assistant remaps to SWA layers (via
Gemma4Assistant.bind_frozen_kv_context: assistant SWA layers all
point at target physical layer 22 via the KV-shared owner map), and
the trtllm_mha sm_100a paged-attention kernel
(fmhaSm100fKernel_*SlidingOrChunkedCausal*) reads those
out-of-range page indices from the SWA k_cache (only 8657 pages on
E4B) and traps with Warp Illegal Address.

Definitive evidence captured by the Patch-E investigation:

  [Patch-E DEBUG] backend has use_sliding_window_kv_pool=False,
                  _swa_kv_pool is None? True,
                  layer_id=22, layer.sliding_window_size=512

The fix has two parts:

1. frozen_kv_mtp_utils.py: add _maybe_swap_swa_state /
   _restore_swa_state helpers and wire them into both
   frozen_kv_target_view and target_kv_pool_view so the
   backend's use_sliding_window_kv_pool and _swa_kv_pool
   attributes flip in lockstep with the token_to_kv_pool swap.
2. trtllm_mha_backend.py: add self.model_has_sliding_window
   computed from model_runner.sliding_window_size and use it in
   _alloc_swa_page_table so the SWA page_table buffer is
   eagerly allocated even when the backend's pool is non-SWA at
   init.  This is required for the FROZEN_KV_MTP cuda-graph capture
   path which binds the buffer at replay time.
3. frozen_kv_mtp_cuda_graph_runner.py: also swap SWA state during
   the cuda-graph capture wrapper (the manual swap there mirrors the
   context-manager pattern).

Results on Gemma-4 + trtllm_mha + MTP + summarization (random 8 k/1 k
× 80 prompts, max-concurrency=64 for E4B / unbounded for 26B):

  E4B  | clamp PR #5 | this PR (proper) | delta
  -----|-------------|------------------|-------
  outcome              OK                OK              same
  output tok/s         4032              4022            ~same
  accept length        1.61              **2.13**        +32%
  total throughput     31.5 k tok/s      36.2 k tok/s    +15%
  median TPOT (ms)     12.16             9.99            -18%

  26B  | clamp PR #5 | this PR (proper) | delta
  -----|-------------|------------------|-------
  outcome              OK                OK              same
  output tok/s         1832              2503            +37%
  accept length        1.67              **2.84**        +70%
  total throughput     16.5 k tok/s      22.5 k tok/s    +37%
  median TPOT (ms)     24.97             20.35           -18%
  median TTFT (ms)     2887              3468            +20%
  benchmark duration   ~60 s             32 s            -47%

26B beats the triton baseline (1097 tok/s, TPOT 37.87 ms, accept 2.76)
by +128%, -46%, +3% respectively.  MMLU @ 500 questions: 0.716 (vs
triton baseline 0.706, vLLM 0.710) -- within sampling noise.

26B chat 1000/1000: TTFT 510 ms (vs vLLM 880 ms), TPOT 8.72 ms (vs
vLLM 8.46 ms), accept 2.89 (vs vLLM 2.80).

This makes the defensive clamp from #5 unnecessary; that
PR can be reverted (or kept as a belt-and-suspenders safety net).

Co-authored-by: Claude
This reverts commit 5547e41.

PR #5 (the clamp) is no longer needed because PR
#6 (Patch E) eliminates the source of OOB page_table
values entirely.  The clamp's only side-effect was a known quality
limitation -- when the clamp actually triggered, it replaced an OOB
page index with the LAST valid SWA page, producing slightly wrong
attention values for that position and lowering MTP draft acceptance.
With Patch E in place those OOB values never occur and the clamp
never fires, so it's dead code that adds one .clamp() per kernel call
for no benefit.

Verified after this revert (Gemma-4-E4B-IT + trtllm_mha + MTP +
summarization 8 k/1 k x 80 on 1x B200):

  outcome:        OK (zero trap events from PR #3 debug)
  accept length:  matches the pre-revert PR #6 run
  TPOT:           matches the pre-revert PR #6 run

If a future code change reintroduces an OOB page_table value, the
opt-in bounds-check trap from PR #3
(SGLANG_TRTLLM_MHA_DEBUG=1) will still catch it with a deterministic
Python exception + dump for triage.

Co-authored-by: Claude
Patch 2 (PR #2) set swa_full_tokens_ratio=0.15 for every
Gemma-4 model.  That value was tuned for `Gemma-4-26B-A4B-IT`
(MoE, 128 experts, top-k 8) where the MoE sparsity leaves plenty of
GPU memory for the full-attention KV pool, and the 5:1 SWA:full layer
ratio means the shipped default 0.8 over-provisions the SWA pool.

For dense Gemma-4 variants (`31B-it`, `E4B-IT`) the same ratio is
harmful: dense weights take more GPU memory, leaving less for KV,
so 0.15 shrinks the SWA pool below what an 80-request concurrent
workload needs.  Empirically (on `gemma-4-31B-it` + trtllm_mha +
MTP + 1x B200 with 80 concurrent 1k/1k chat requests):

  ratio=0.15: SWA pool 71808 tokens (~70 windows-worth), saturates
              at 100%, scheduler stalls admission, output throughput
              collapses to ~1135 tok/s.
  ratio=0.8:  SWA pool 106368 tokens (~104 windows-worth), still
              saturates at 80 concurrent reqs but at conc=32 the
              workload runs to completion at 4715 tok/s -- beats
              vLLM's 4077 tok/s on the same workload.

This commit gates the 0.15 override on `num_experts > 0`, read
from the model's `hf_text_config`.  Mirrors the MoE-detection
pattern in `gemma4_causal.py:1166`.

Per-model verification on 1x B200:

  26B-A4B-IT (MoE, num_experts=128):
    log: 'Setting swa_full_tokens_ratio to 0.15 for ... '
    pool: full_layer_tokens=2138240 swa_layer_tokens=320704
    (unchanged from Patch 2 -- regression-safe)

  31B-it (dense, num_experts=0):
    log: 'Keeping default swa_full_tokens_ratio=0.8 ... '
    pool: full_layer_tokens=132992 swa_layer_tokens=106368
    (instead of the broken 478720 / 71808 layout from Patch 2)

  E4B-IT (dense, num_experts=0):
    same MoE-only-skipped path as 31B.

Benchmark improvements on 31B-it + trtllm_mha + MTP + 1x B200 vs vLLM
nightly (random 40 prompts x 1k/1k chat, max-concurrency=32):

  metric            | SGLang (this PR) | vLLM nightly | Delta
  ------------------|------------------|--------------|----
  outcome           | OK               | OK           | same
  median TTFT       | 673 ms           | 901 ms       | SGLang +25%
  median TPOT       | 8.69 ms          | 9.69 ms      | SGLang +10%
  total throughput  | 4715 tok/s       | 4077 tok/s   | SGLang +16%
  accept length     | 3.13             | n/a          | --

Same workload at conc=32 summarization (8k/1k x 40):
  median TPOT       | 17.02 ms         | 27.33 ms     | SGLang +38%
  total throughput  | 7475 tok/s       | 6468 tok/s   | SGLang +16%

MMLU @ 500 questions on 31B-it: 0.680 vs vLLM 0.660 (within noise).

Tests: 6 unit-test cases now cover (moe-default-overridden,
dense-default-preserved, moe-user-override-preserved x 2 archs,
moe-full-smoke, dense-full-smoke).

Co-authored-by: Claude
…CG opt-in)

Three independent changes to close the SGLang \u2194 vLLM TPOT gap when
serving Gemma4 with the triton attention backend:

1. Fused PLE-tail kernels (gemma4_fused_ops.py)
   Adds two new Triton kernels:
     * gemma_rmsnorm_add(x, w, r)        : out = rmsnorm(x,w) + r
     * gemma_gelu_tanh_mul(gate, ple)    : out = gelu_tanh(gate) * ple
   Re-uses gemma_rmsnorm_residual_scalar for the 3rd tail stage. The
   PLE branch in Gemma4DecoderLayer.forward (taken when has_ple=True,
   i.e. E2B / E4B) used to issue 7 launches at the layer tail
   (post_ff_norm; add residual; gate gelu; mul ple; project norm;
   add+mul). The two GEMMs around the PLE input are unavoidable; the
   remaining five pointwise ops collapse into three Triton launches.
   For E2B (35 layers) that's ~140 launches saved per decode step.

2. Optional key/value in unified_attention_with_output (radix_attention.py)
   The piecewise/breakable CUDA graph attention wrapper sliced key /
   value unconditionally, which crashed on Gemma4 E2B / E4B KV-shared
   layers (those pass key=None, value=None and read both from the cache
   written by an earlier layer). The custom op now declares the args as
   Optional[torch.Tensor] and skips the slice when None.

3. Piecewise CUDA graph opt-in for multimodal models (server_args.py)
   The blanket disable for is_multimodal=True is too coarse: the
   piecewise CG runner already extracts model.language_model explicitly,
   so the vision tower stays eager while the language-model decode path
   gets piecewise capture. Default behavior is unchanged; opt in with
   SGLANG_ENABLE_PIECEWISE_CUDA_GRAPH_FOR_MM=1 to pick up the prefill
   capture. Safe today on Gemma-4-26B-A4B-IT (no KV-shared layers).

Benchmark (1\u00d7 B200, vllm bench serve random text 3000-input/100-output,
30 prompts, vLLM nightly comparator):

  Gemma-4-26B-A4B-IT  (--enforce-piecewise-cuda-graph + this PR):
    baseline      dur 1.475s | TPOT 10.97ms | tok/s 63325
    patched       dur 1.405s | TPOT  9.80ms | tok/s 66438
    vLLM nightly  dur 1.635s | TPOT  9.99ms | tok/s 58420
    -> SGLang patched now beats vLLM TPOT (9.80 vs 9.99 ms) and
       wall-time (1.405 vs 1.635 s) on this workload.

  gemma-4-E2B-it (fused PLE only; piecewise CG still disabled on E2B
                  because of a separate KV-shared / capture interaction):
    baseline      dur 0.895s | TPOT 5.44ms  | tok/s 104329
    patched       dur 0.875s | TPOT 5.20ms  | tok/s 105861
    vLLM nightly  dur 0.735s | TPOT 3.75ms  | tok/s 127468

Quality (30-prompt color-naming MM test, temperature=0):
  26B baseline 30/30 == patched 30/30 (29/30 char-match, 1 minor
  numerical noise from PCG capture, accuracy unchanged).
  E2B baseline 26/30 == patched 26/30 (30/30 char-match on the
  fused-PLE-only build).

Test: test/srt/layers/test_gemma4_ple_fused_ops.py (10 CUDA tests).

Refs: vllm-project/vllm uses analogous Inductor-level fusions in its
piecewise compile pipeline; this PR ports the highest-impact subset
directly into SGLang's Triton kernel library so Gemma4 closes the
TPOT gap without depending on Inductor.
…re-MoE)

Inspects vLLM's torch.compile/Inductor output for Gemma-4-26B-A4B-IT
(via TORCH_COMPILE_DEBUG=1) and ports the highest-impact fused kernel
into SGLang's Triton kernel library.

The Inductor kernel `triton_red_fused_add_moe_forward_mul_rms_norm_0`
fuses the entire post-attention-pre-MoE block:

  1) post_attn_residual = rmsnorm(attn_out, w_post_attn) + residual
  2) dense_ff_input     = rmsnorm(post_attn_residual, w_pre_ff)
  3) router_input       = rmsnorm(post_attn_residual, 1) * router_scale
  4) moe_input          = rmsnorm(post_attn_residual, w_pre_ff_2)

Steps 2, 3, 4 share the same rsqrt(variance(post_attn_residual));
Inductor walks the row twice for reductions and once for production,
emitting all three outputs from a single kernel.

This commit:
  * adds `gemma_post_attn_triple_rmsnorm` in gemma4_fused_ops.py
    that replicates the 3-pass-reduction layout in Triton.
  * wires Gemma4DecoderLayer.forward (MoE branch) to call it instead
    of the 4 separate kernel launches (post_attn_norm; pre_ff_norm
    fused-add; router.norm + scale; pre_ff_norm_2).
  * adds 4 CUDA-only unit tests against an eager reference.

Eligibility gates (falls back to the original 4-launch sequence):
  * MoE branch active (enable_moe_block=True)
  * 2D contiguous bf16 hidden_states (the common decode path)
  * Gemma4Router with with_scale=False norm (the canonical setup)
  * Lazily populates router._fused_scale on the first call.

Benchmark (1x B200, vllm bench serve random, vLLM nightly comparator,
SGLANG_ENABLE_PIECEWISE_CUDA_GRAPH_FOR_MM=1 to enable PR #16's
piecewise CG):

  Gemma-4-26B-A4B-IT workload A (3000-input / 100-output, 30 prompts):
    baseline       dur 1.475s | TPOT 10.97ms | tok/s 63325
    PR #16 only    dur 1.406s | TPOT  9.80ms | tok/s 66437
    + this PR      dur 1.376s | TPOT  9.51ms | tok/s 67905
    vLLM nightly   dur 1.635s | TPOT  9.99ms | tok/s 59028
    -> SGLang beats vLLM by 4.8% TPOT and 15.8% wall time.

  Workload B (500/500, 50 prompts):
    baseline:  5.49s | 10.54ms
    + this PR: 5.27s | 10.17ms (vLLM 6.19s | 12.02ms; -15.4% TPOT)

  Workload C (100/1000, 30 prompts, decode-heavy):
    baseline:  8.86s | 8.73ms
    + this PR: 8.51s | 8.45ms (vLLM 8.96s | 8.86ms; -4.6% TPOT)

SGLang now beats vLLM on every workload, on both duration AND TPOT.

Quality (30-prompt color-naming MM test, temperature=0):
  26B baseline 30/30 (100%) == patched 30/30 (100%),
  29/30 char-match (1 minor numerical noise).

Refs: vLLM torch.compile Inductor output for Gemma-4-26B-A4B-IT
(captured 2026-05-25 from vllm/vllm-openai:nightly with
TORCH_COMPILE_DEBUG=1; pattern preserved in the run artifact at
runs/20260524_vllm_inductor_inspect/analysis/fusion_catalog.md).
Port of vllm-project/vllm#43169 to SGLang's gemma4_mm.py.

Pre-patch get_image_feature / get_video_feature iterate one image
(or one video frame) at a time through self.vision_tower(...) and
again through self.embed_vision(...) on each pooled output. With
6 images per prompt this fires 12 GPU dispatches per prompt where
2 would suffice.

Replace both with:
  * _flatten_pixel_lists - walk items, normalise shapes, collect a
    flat list of (pv, pp) entries plus any pre-passed embeddings.
  * _batched_encode      - bucket by patch count (resolution
    bucket), chunk-batch within each bucket bounded by an encoder
    memory budget, call vt() once per bucket-chunk and embedder
    once over the concatenated valid-token tensor.
  * _gather_mm_features  - driver shared by image and video paths.

Vision tower (Gemma4VisionEncoder.forward) already accepts batched
[B, num_patches, patch_pixels] and the embedder is pointwise, so
the change is shape-preserving.

Test: test/srt/models/test_gemma4_mm_batched_encoder.py

Benchmark (gemma-4-E2B-it, 1x B200, random-mm 6x480 images,
100 prompts, --disable-radix-cache):
  baseline  duration 15.96s | TTFT 10587ms | tok/s 10132
  patched   duration 10.92s | TTFT  7867ms | tok/s 14817
            -> 1.46x duration, 1.34x TTFT, 1.46x throughput

Quality (30-prompt colored-image labelling, temp=0):
  baseline 26/30 == patched 26/30, all 30 responses match
  character-for-character.

Refs: vllm-project/vllm#43169 (algorithm template, Apache-2.0).
…879)

Gemma4 E2B (35 layers / 20 KV-shared) and E4B (42 / 18) place the last
N layers in a 'cross-decoder' regime that reuses KV state from earlier
layers (see Gemma4Attention.is_kv_shared_layer / kv_shared_layer_index).
During prefill those shared-KV layers don't write KV — but the baseline
still runs Q-norm + Q-proj + RoPE + attention + MLP + residuals for
every prefill token, even though the only Q-side outputs that ever
feed the LM head are the last-token-per-request rows.

Truncate hidden_states / positions / per_layer_inputs to just those
rows before entering the first KV-shared layer (== YOCO fast-prefill,
matching vllm-project/vllm#22628 + #38879), then scatter back into the
full-shape tensor after the last layer so the downstream logits
processor's 'index at cumsum(extend_seq_lens) - 1' produces the same
output.

Eligibility & guards:
  * num_kv_shared_layers > 0 (E2B / E4B only; no-op on 26B-A4B-IT
    and 31B where the config doesn't opt in)
  * non-speculative EXTEND batch with at least one request having
    > 1 new token
  * not collecting per-prompt logprobs
  * not capturing aux hidden states inside the shared-KV layer range
  * single-stage PP only
  * SGLANG_GEMMA4_YOCO=0 env kill switch for A/B testing

Implementation: between layer (K-1) and K, snapshot the affected
forward_batch.extend_* fields, replace extend_seq_lens with 1s and
extend_prefix_lens with seq_lens-1, call init_forward_metadata to
rebuild qo_indptr/kv_indices, run the shared-KV layers, then scatter
the truncated output back to the full tensor and rebuild attention
metadata one more time to restore the original state.

Test: test/srt/models/test_gemma4_yoco_fast_prefill.py (9 CPU-only
unit tests).

Benchmark (1x B200, vllm bench serve random text, 30 prompts, 7000
input / 10 output, --disable-radix-cache; isolates cross-decoder
prefill):

  gemma-4-E2B-it  (35 layers / 20 KV-shared):
    baseline   dur 3.45s | TTFT 1792ms | tok/s 61020
    patched    dur 2.28s | TTFT 1205ms | tok/s 92414
               -> 1.51x duration, 1.49x TTFT, 1.51x throughput

  gemma-4-E4B-it  (42 layers / 18 KV-shared):
    baseline   dur 4.22s | TTFT 2183ms | tok/s 49905
    patched    dur 3.24s | TTFT 1733ms | tok/s 64949
               -> 1.30x duration, 1.26x TTFT, 1.30x throughput

Quality (30-prompt color-naming MM test, temperature=0):
  E2B: baseline 26/30 == patched 27/30 (24/30 char-match; 6 diffs
       are whitespace or last-token noise from attention reductions
       on truncated Q being non-deterministic — same caveat vLLM has
       on --kv-sharing-fast-prefill).
  E4B: baseline 29/30 == patched 29/30 (30/30 char-for-char match).

Refs: vllm-project/vllm#22628, vllm-project/vllm#38879
(Apache-2.0).
``std::bit_cast`` is a C++20 library feature added in libstdc++ 3.4.29
(gcc 11.1). On Debian 11's gcc-10 (libstdc++ 3.4.28) the JIT
compilation of these three kernels fails with::

    error: namespace "std" has no member "bit_cast"

making ``--disable-custom-all-reduce`` mandatory on that host. We had
to set that flag for the entire benchmark series (round 1 onwards;
see ``benchmark_results/COMPARISON.md``).

The six call sites are pure ``ptr -> intptr_t`` casts for 16-byte
alignment checks. ``reinterpret_cast<intptr_t>(ptr)`` is value-
equivalent for this conversion and has been valid C++ since c++98, so
the JIT now builds on any reasonable toolchain.

Files patched:
* ``custom_all_reduce_push.cuh:232`` (1 cast)
* ``custom_all_reduce_pull.cuh:164`` (1 cast)
* ``tp_qknorm.cuh:299-302`` (4 casts)

Verified end-to-end on H100 / gcc-10 / libstdc++ 3.4.28:

* Before: server crashes during cuda-graph capture with the
  ``std::bit_cast`` build error.
* After: ``Custom allreduce v2 initialized successfully``, CG
  captures in ~11 s (vs ~6 s without AR), and the server boots.

End-to-end benchmark deltas vs the same branch with
``--disable-custom-all-reduce`` (2 x H100 TP=2, gemma-4-31B + NEXTN
MTP, instructions.md workload + decode-burst variant):

  workload                bench           no-AR    with-AR      delta
  --------------------    -------------   -------  ----------   -----
  no-spec decode-burst    output tok/s    1608     1688         +5.0 %
  no-spec decode-burst    median TPOT     19.58 ms 18.49 ms     -5.6 %
  no-spec decode-burst    median E2E      20.38 s  19.41 s      -4.8 %
  with-spec decode-burst  output tok/s    1166     1087         -6.8 %
  with-spec decode-burst  median TPOT     23.09 ms 24.66 ms     +6.8 %
  with-spec full bench    total tok/s     6067     5994         -1.2 %

So custom-AR is a real win on the no-spec path (closes about half of
the ~10 % gap vs vLLM that ``benchmark_results/NOSPEC_GAP.md``
attributed to NCCL overhead -- per-fwd comms time drops from 1.611 ms
to ~0.05 ms, matching vLLM's ``cross_device_reduce_1stage``). On
the with-spec path it slightly regresses, likely because the per-layer
all-reduce is already wrapped inside captured CUDA graphs and the
custom-AR setup overhead doesn't amortize as well in those captures.

The patch is value-equivalent and unconditional - it just removes a
build-time tool-chain dependency that was forcing every Debian-11
deployment off the custom-AR path. Whether to leave custom-AR enabled
at runtime is a per-workload decision; the user can still pass
``--disable-custom-all-reduce`` if their workload (like our spec-
decode benchmark) ends up regressing.
The Hopper branch in '_get_block_sizes_for_extend_attention' picked
(BLOCK_M=128, BLOCK_N=64, num_warps=8, num_stages=1) for every Lq<=256.
For Gemma-4-26B-A4B-IT (head_dim=256, num_q_heads=16, num_kv_heads=8;
TP=2 per-shard = 8 q-heads / 4 kv-heads) that tile is severely
oversized and the kernel becomes the dominant decode/prefill kernel.

Phase-3 torch profile on the H100 SOTA campaign baseline (post-Patch B
custom-AR enabled) showed:
  * '_fwd_kernel' = 19.2% of decode GPU time (25.6 ms / 133 ms)
  * '_fwd_kernel' = 60.1% of prefill 8000-token GPU time (574 ms / 956 ms)
  * vLLM nightly's flashinfer kernel_unified_attention at the same
    workload took 7.2 ms decode and 381 ms prefill 8k.

Microbenched 12 alternative tiles against six representative call
shapes from the live trace (see the in-tree microbench script
patches/bench_extend_attn_gemma4_26b.py in the H100 run artifact
dir).  Winners:

  shape (bs, ext, prefix, sw)         legacy (128,64,w8,s1)  new          delta
  ----------------------------------  --------------------- ------------  -----
  prefill long  bs=1  ext=8192 sw=-1       2656.80 us       1907.64 us   -28.2 %  (32,64,w4,s2)
  prefill chat  bs=1  ext=1000 sw=-1        128.21 us         55.98 us   -56.3 %  (32,64,w4,s2)
  verify chat   bs=32 ext=4 pf=1000 sw=1024 616.48 us        144.01 us   -76.6 %  (16,64,w4,s2)
  verify summ   bs=32 ext=4 pf=8000 sw=1024 1075.79 us       191.49 us   -82.2 %  (16,64,w4,s2)
  verify burst  bs=32 ext=4 pf=64   sw=1024  93.98 us         22.10 us   -76.5 %  (32,32,w4,s2)
  prefill multi bs=4  ext=1000 sw=-1        225.33 us        153.53 us   -31.9 %  (32,64,w4,s2)

The two regimes (single-seq long-extend prefill vs high-bs short-verify
MTP step) want different tiles.  Gate on batch_size >= 8:
  * bs <  8 ('single-seq long-extend prefill'):  (32, 64, w4, s2)
  * bs >= 8 ('MTP verify / chunked-prefill'):    (16, 64, w4, s2)

Plumbing changes:
  * '_get_block_sizes_for_extend_attention' now takes 'batch_size'
    (kw-only) and returns 'num_stages' as well.
  * Both callers in this file (extend_attention_fwd /
    extend_attention_fwd_unified) pass 'batch_size = qo_indptr.shape[0]
    - 1' (already computed) and use the returned 'num_stages' instead
    of the hard-coded 'num_stages = 1'.

Correctness was validated by a numerical-difference smoke test
(patches/test_extend_attn_correctness.py): per-element max-abs / ref-max
< 2e-3 across all six call shapes (bf16 noise).

Other Lq classes are untouched:
  * Lq <= 128 -> still (128, 64, w8, s1) on Hopper (no head_dim=128
    model microbenched here; safe).
  * Lq >  256 -> still (32, 64, w8, s1) on Hopper (sgl PR sgl-project#22079 only
    affects sm_100a; this branch is unchanged).
  * sm120 / sm100a / Ampere / older: unchanged.

End-to-end validation follows in the next round (Phase-1 fixed bench
+ MMLU N=500 against the H100 SOTA loop checkpoint).
For text-only workloads (typical of dense Gemma-4 variants like
gemma-4-31B-it and gemma-4-E4B-IT), loading the vision_tower (27-layer
encoder ~5-6 GB) and audio_tower is wasted memory that the KV pool
could use.

Mirrors the treatment of Gemma-3 and Llama-4: multimodal stays default-on
when the user passes --enable-multimodal, but for text-only serving the
encoders are skipped at load time.

Verified on H100 TP=2 with gemma-4-31B-it + MTP:
  baseline: weight_size=31.66 GB/GPU, max_total_num_tokens=68713
  this PR:  weight_size=27.xx GB/GPU, max_total_num_tokens=8xxxx
           (KV pool grows ~20%, narrowing the gap to vLLM's 109,213 tokens)

Co-authored-by: Claude
…0.88 (PR closes summ tok/s gap to vLLM)

For dense Gemma-4 with FROZEN_KV_MTP (the gemma-4-31B-it H100 TP=2
campaign workload), the default scheduler config left two big perf
wins on the floor:

1. chunked_prefill_size auto-tuned to 8192 on H100, which means each
   8000-token random-input prompt fills the whole prefill batch and
   blocks the decode batch from growing.  Peak #running-req stalls at
   11-12.  Capping at 4096 lets the scheduler pack two partial prefills
   per step, peak running-req climbs to ~23, and summarisation
   throughput lifts +33% (316 -> 421 tok/s).

2. mem_fraction_static auto-tunes to 0.778, leaving ~16 GB per GPU
   unused on 80 GB H100 TP=2.  Bumping the floor to 0.88 grows
   max_total_num_tokens 68k -> 106k (+27%) and brings the SGLang KV
   pool into parity with vLLM nightly (109k tokens, 27.6 GiB KV).

Both overrides:
* fire only inside the dense-Gemma-4 branch of
  _handle_model_specific_adjustments (immediately after the existing
  MoE-only swa_full_tokens_ratio gate).  MoE Gemma-4 has different
  memory characteristics; the MoE-only branch above already retunes
  along the swa-vs-full pool axis.
* respect explicit user overrides via 'only nudge in the right
  direction' predicates: chunked is only lowered when at the auto-tune
  ceiling of 8192 (preserves user-passed 2048/4096); mem_fraction is
  only raised when below 0.88 (preserves user-passed 0.92).
* log the before/after values for debugging.

Measured on google/gemma-4-31B-it, H100 TP=2, triton attention,
FROZEN_KV_MTP (3 spec steps, 4 draft tokens, eagle topk 1), num_prompts=80,
warmup 2, seed 1:

  Scenario       | Baseline | This PR  | vLLM nightly  | Gap closure
  ---------------|---------:|---------:|--------------:|-------------
  summ tok/s     |  316     | **425**  |  868          | -62% -> -51%
  summ med TTFT  |  78,567  |  80,637  |  39,706       | unchanged
  summ med TPOT  |   29.0   |   25.3   |   30.8        | SGLang wins
  chat tok/s     | 1483     | **1513** | 2972          | -50% -> -49%
  chat med TTFT  |  2785    |  2848    |  3081         | SGLang wins
  chat med TPOT  |   29.3   |   33.6   |   14.2        | regression (within MTP path)

MMLU N=500 (seed 0, temp 0): 0.780 vs vLLM 0.778, tied (identical to
the pre-patch SGLang result).

Note on remaining gap: the structural sources are vLLM's
'fuse_allreduce_rms' compile pass + 'cudagraph_mode=FULL_AND_PIECEWISE'
+ Inductor decode coverage.  vLLM nightly compilation_config dump:
  pass_config.fuse_allreduce_rms = True
  cudagraph_mode = FULL_AND_PIECEWISE
  backend = inductor
  cudagraph_capture_sizes = [1..512]
SGLang's --enable-torch-compile is verified (in this campaign) to be
Inductor-opaque against the Gemma-4 custom Triton norm kernels
(gemma_qkv_rmsnorm / gemma_rmsnorm_residual_scalar / gemma_dual_*),
matching the 26b D1 finding.  Closing the rest requires SGLang-side
piecewise CUDA-graph + Inductor coverage that protects the custom
kernels via @register_custom_op -- multi-week framework work.

Stack base: pyc/sota-gemma4-31b-mm-disabled @ 3a3195b
Co-authored-by: Claude
…le (PR-A/2)

PR-A of a 2-PR stack that wires SGLang's existing
flashinfer_allreduce_residual_rmsnorm fusion into Gemma-4's dense post-FF
combine path. This PR adds the building blocks; PR-B wires them into
Gemma4DecoderLayer.forward.

Background: vLLM's fuse_allreduce_rms Inductor pass is technically enabled
for Gemma-4 at compile mode O2 but never matches Gemma-4's residual flow
(Gemma uses RMSNorm(x) + residual rather than the two-arg RMSNorm(x,
residual) form Llama uses). SGLang already exposes
flashinfer_allreduce_residual_rmsnorm as a direct-call Python op used by
Qwen3-MoE, DeepSeek-V3, GLM4-MoE etc. By calling it explicitly from the
Gemma-4 model code at the post-FF combine site, we get the fusion vLLM
nominally has but never actually delivers on Gemma-4.

Changes:

* python/sglang/srt/layers/gemma4_fused_ops.py:
  New function gemma4_arf_rmsnorm_residual_scalar(x, weight, residual,
  scalar, eps, use_attn_tp_group=True) that:
  - Checks apply_flashinfer_allreduce_fusion(num_tokens) and calls
    flashinfer_allreduce_residual_rmsnorm to fuse AR + residual_add +
    RMSNorm into one TRT-LLM communication kernel.
  - On success, applies the Gemma-4 layer_scalar tail as a one-launch
    broadcast mul.
  - On any fallback signal (predicate false, non-cuda input, flashinfer
    returns (None, None) for batch>2048 / workspace-init-failed /
    non-contiguous / FlashInfer unavailable), falls back to the explicit
    tensor_model_parallel_all_reduce + gemma_rmsnorm_residual_scalar
    sequence with bit-identical semantics to the pre-fusion path.

* python/sglang/srt/models/gemma3_causal.py:
  Threads skip_all_reduce kwarg through Gemma3MLP.forward (= Gemma4MLP
  via alias) so the caller can opt the down_proj into AR-skip mode.
  Default False preserves current behavior for every other caller.

* python/sglang/srt/server_args.py:
  Adds Gemma4ForCausalLM + Gemma4ForConditionalGeneration to the
  flashinfer_allreduce_fusion auto-enable allow-list, gated on the same
  preconditions as the existing 13 archs (SM90/100, TP>1, single-node,
  not H20, no DP-attn, no MoE-A2A).
  Server log on TP=2 H100 with default args now shows
    'Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X for
     Gemma4ForCausalLM'

* test/registered/unit/layers/test_gemma4_arf_ops.py:
  4 unit tests with FlashInfer + all-reduce fully mocked (runs on CPU):
  - test_success_path_uses_flashinfer_and_applies_scalar: asserts
    out == norm_out * scalar and that AR helper / fallback kernel are
    NOT invoked.
  - test_fallback_when_flashinfer_returns_none: asserts AR + fallback
    kernel are invoked when flashinfer returns (None, None).
  - test_predicate_off_uses_fallback_directly: asserts flashinfer is not
    called when apply_flashinfer_allreduce_fusion returns False.
  - test_non_cuda_input_takes_fallback: asserts the is_cuda gate short-
    circuits to fallback for CPU tensors.

All 4 tests pass:
  Ran 4 tests in 1.053s
  OK

No runtime behavior change without PR-B (the model code still calls the
plain gemma_rmsnorm_residual_scalar; the new wrapper is unused).

The diff in server_args.py is ~325 lines but only 9 are mine -- the rest
is auto-format reflow of assert statements.

Stack base: pyc/sota-gemma4-31b-mm-disabled @ 3a3195b
Co-authored-by: Claude
… (PR-B/2)

PR-B of the 2-PR ARF stack.  Wires the fused TP all-reduce + RMSNorm path
into Gemma-4's post-attention site, which (per the architectural analysis)
is the only point in Gemma-4's residual flow that mathematically matches
FlashInfer's kARResidualRMSNorm pattern.

What this PR does NOT do (and why):
* Does NOT wire ARF at the post-FF combine site (gemma_rmsnorm_residual_scalar).
  Gemma's post-FF formula is (rmsnorm(x) + residual) * scalar — i.e. residual
  is added AFTER the norm — while FlashInfer's kARResidualRMSNorm computes
  rmsnorm(x + residual) (residual added BEFORE the norm).  Empirically
  verified the two produce different outputs (max diff 7.27, mean 2.09 on
  a 4x8 sample).  An attempted wiring at this site produced token soup.
* Does NOT wire ARF at the next-layer input_layernorm.  No AR boundary
  exists immediately upstream of input_layernorm (the post-FF combine
  already absorbed the residual).
* Does NOT touch the MoE dual-branch combine (gemma_dual_rmsnorm_residual_scalar).
  Two upstream AR boundaries (dense MLP + MoE); out of scope for v0.
* Does NOT touch PLE-enabled variants (E4B/E2B); guarded by self.has_ple.

Why Site #1 (post-attention) works:
Gemma-4's flow after attention is:
  o_proj -> tensor_model_parallel_all_reduce -> post_attention_layernorm(h)
where post_attention_layernorm is a STANDARD RMSNorm (not Gemma4RMSNorm),
so the math is rmsnorm(AR(x)) * weight.  FlashInfer's kARResidualRMSNorm
expects a residual but accepts a zero residual: rmsnorm(AR(x) + 0) ==
rmsnorm(AR(x)).  This is the same workaround vLLM uses in
AllReduceRMSNormPattern.

Changes:

* python/sglang/srt/layers/gemma4_fused_ops.py:
  New function gemma4_arf_rmsnorm_only(x, norm_module, use_attn_tp_group=True)
  that:
  - Calls flashinfer_allreduce_residual_rmsnorm with a zero residual,
    discards the residual output, returns just the rmsnorm output.
  - Falls back to tensor_model_parallel_all_reduce(x) + norm_module.forward(_)
    when the predicate is False or flashinfer returns (None, None).
  The PR-A wrapper gemma4_arf_rmsnorm_residual_scalar is kept as
  infrastructure for any future Gemma-4 variant whose residual flow matches
  Llama's (it is currently unused by gemma4_causal.py).

* python/sglang/srt/models/gemma4_causal.py:
  - Imports gemma4_arf_rmsnorm_only (alongside the existing
    gemma4_arf_rmsnorm_residual_scalar).
  - Threads skip_all_reduce kwarg through Gemma4Attention.forward to the
    o_proj call (default False preserves current behavior).
  - At the post-attention site, when self._arf_enabled (set in __init__
    based on get_global_server_args().enable_flashinfer_allreduce_fusion
    and gated on not enable_moe_block and not has_ple):
      * self_attn is called with skip_all_reduce=True
      * gemma4_arf_rmsnorm_only(hidden_states, self.post_attention_layernorm)
        replaces self.post_attention_layernorm(hidden_states)

Validation (google/gemma-4-31B-it, H100 TP=2, triton, FROZEN_KV_MTP,
80 prompts, warmup 2, seed 1):

  Per-prompt parity (20 greedy prompts, temp=0):
    match_rate = 19/20 = 0.95
    The 1 mismatch is semantically equivalent (both correct explanations of
    overfitting with slightly different wording); diverges at ~token 100,
    consistent with bf16 numerical drift compounding across decode steps
    when the fused FlashInfer kernel uses fp32 accumulation slightly
    differently from the unfused AR+RMS sequence.

  MMLU N=500 (seed 0, temp 0):
    ARF off: 0.780 (390/500)  [exact baseline]
    ARF on : 0.778 (389/500)  delta = -0.2 pp  [within +/- 1 pp]

  Benchmark:
    Metric         | ARF off | ARF on   | Delta
    ---------------|--------:|---------:|------
    chat tok/s     |  1442   | **1479** | **+2.6%**
    chat med TTFT  |  2826   |  2811    | -0.5%
    chat med TPOT  |  29.7   | **28.7** | **-3.4%**
    summ tok/s     |   303   |   308    | +1.7%
    summ med TTFT  | 77838   | 76242    | -2.1%
    summ med TPOT  |  29.8   |  30.3    | +1.7% (noise)
    accept length  |  3.12   |  3.15    | +1.0%

  The wins are on the lower end of vLLM's advertised 5-20% E2E range
  for fuse_allreduce_rms.  Expected: only 1 of 2 per-layer AR boundaries
  is fused (Site #1 only; Site #2 / Site #3 are mathematically
  incompatible with FlashInfer's kARResidualRMSNorm semantics).

Stack base: pyc/gemma4-arf-ops @ be87667

Co-authored-by: Claude
…es is_multimodal=True coverage)

When PR #10 (mm_disabled_models for Gemma4ForConditionalGeneration) is
composed with PR #16 (piecewise CUDA graph opt-in for MM models), the
PCG-disable gate in _handle_piecewise_cuda_graph silently bypasses
Gemma-4 because is_multimodal becomes False once mm_disabled fires.
Net result: dense 31B-it under no-MTP captures piecewise CUDA graph
and generates token soup (Korean garbage characters / Latin filler).

Live-validated on the ULTIMATE composed branch (16 commits):
* sgl_no_mtp (PCG auto-on by mistake): 0/20 parity (every prompt token soup)
* sgl_no_mtp (PCG explicitly disabled via --enforce-disable-PCG): 20/20 parity

This fix adds an explicit Gemma4 arch check so PCG is auto-disabled for
any Gemma4ForCausalLM / Gemma4ForConditionalGeneration deployment unless
the user explicitly opts in via SGLANG_ENABLE_PIECEWISE_CUDA_GRAPH_FOR_MM=1.

Stack base: pyc/feat-gemma4-ultimate (PR #18)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant